# -*- coding:utf-8 -*-

import numpy as np
from config.glob.global_pool import global_pool

"""
drnn 预测函数
"""


def pick_top_n(preds, vocab_size2, top_n=5):
    """
    Sample one index from the ``top_n`` largest entries of ``preds``.

    :param preds: array-like of non-negative scores/probabilities; size-1 axes
        are squeezed away (e.g. a ``(1, vocab)`` row becomes ``(vocab,)``).
        The input is copied, so the caller's array is never modified.
    :param vocab_size2: number of candidates passed to ``np.random.choice``;
        must equal the length of the squeezed ``preds``.
    :param top_n: how many of the largest entries to keep; values ``<= 0`` or
        ``>= len(preds)`` keep every entry.
    :return: the sampled index.
    :raises ValueError: if the kept entries do not sum to a positive value.
    """
    # np.squeeze alone returns a view; copy (as float64) so the zeroing below
    # cannot write into the caller's array.
    p = np.squeeze(np.array(preds, dtype=np.float64))
    # Zero every entry except the top_n largest. Guard the slice: [:-0] is
    # empty and a negative top_n would zero the wrong entries.
    if 0 < top_n < p.size:
        p[np.argsort(p)[:-top_n]] = 0
    total = np.sum(p)
    # `not total > 0` also rejects NaN; previously this surfaced later as an
    # opaque ValueError from np.random.choice.
    if not total > 0:
        raise ValueError("pick_top_n: kept probabilities must sum to a positive value")
    # Renormalise the kept entries and draw a single index.
    p = p / total
    return np.random.choice(vocab_size2, 1, p=p)[0]


def _drnn_step(model, char_id, state):
    """Feed a single character id through the model; return ``[preds, next_state]``."""
    x = np.zeros((1, 1))
    x[0, 0] = char_id
    feed = {model.xs: x, model.initial_state: state}
    return model.sess.run([model.proba_prediction, model.final_state], feed_dict=feed)


def predict(self, prime, sentence_len=None):
    """
    DRNN prediction: warm the model up on ``prime``, then sample new characters.

    :param self: model object exposing ``sess``, ``initial_state``, ``xs``,
        ``proba_prediction`` and ``final_state`` (used as a session plus
        feed/fetch targets).
    :param prime: seed text, converted to ids by
        ``global_pool.embedding.text_to_arr``.
    :param sentence_len: number of feed-back sampling iterations; defaults to
        ``global_pool.config.sentence_len``.
    :return: ``global_pool.embedding.arr_to_text`` of the prime ids followed by
        the ``sentence_len + 1`` sampled ids.
    """
    embedding = global_pool.embedding
    vocab_size = embedding.vocab_size
    if sentence_len is None:
        sentence_len = global_pool.config.sentence_len

    prime = embedding.text_to_arr(prime)
    samples = list(prime)

    state = self.sess.run(self.initial_state)
    # With an empty prime the warm-up loop never runs; start from an
    # all-ones score vector so the first sample can still be drawn.
    preds = np.ones(vocab_size)
    # Feed the prime one character at a time, carrying the RNN state forward.
    for char_id in prime:
        preds, state = _drnn_step(self, char_id, state)

    # First new character comes from the distribution after the prime.
    char_id = pick_top_n(preds, vocab_size)
    samples.append(char_id)

    # Feed each sampled character back in to draw the next one.
    # NOTE(review): this yields sentence_len + 1 new characters in total —
    # confirm that is the intended length.
    for _ in range(sentence_len):
        preds, state = _drnn_step(self, char_id, state)
        char_id = pick_top_n(preds, vocab_size)
        samples.append(char_id)

    return embedding.arr_to_text(np.array(samples))
